{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 相关库导入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from charm.toolbox.integergroup import RSAGroup\n",
    "from charm.schemes.pkenc.pkenc.paillier99 import Ciphertext, Pai99\n",
    "import keras\n",
    "import numpy as np\n",
    "from keras.datasets import mnist\n",
    "from sklearn .svm import LinearSVC\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.linear_model import SGDClassifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据集处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(x_train,y_train),(x_test,y_test) = mnist.load_data() #将训练集和测试集合并\n",
    "x = np.concatenate((x_train,x_test),axis = 0)\n",
    "y = np.concatenate((y_train,y_test),axis = 0)\n",
    "x=x.reshape(70000,784)\n",
    "x= x.astype (' float32')\n",
    "x /= 255"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train,X_test,y_train,y_test = train_test_split(x,y,test_size = 0.1,random_state = 666)\n",
    "iterations = 10\n",
    "clients = 11"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 生成公钥和私钥"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "file = open(\"consuming times.txt\",\"a\")\n",
    "start = time.time ()\n",
    "group=RSAGroup()\n",
    "pai=Pai99(group)\n",
    "pk,sk = pai.keygen(secparam=128)\n",
    "gen_public_param_time = time.time()-start"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 针对每个客户端获得训练出的模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(iterations):\n",
    "# 每个 client 需要用各自的第  份数据训练模型\n",
    "    start = time.time()\n",
    "    for j in range(clients):\n",
    "        models[j].partial_fit(train_x[j][i],train_y[j][i],classes = np.array([0,1,2,3,4,5,6,7,8,9]))\n",
    "    gen_model_time = time.time()-start\n",
    "    local_train_time += gen_model_time\n",
    "    print(i,\" the iteration fit Ok\")\n",
    "    y_predict = models[0].predict (x_test)\n",
    "    print(i,\" the accuracy:\",accuracy_score(y_test,y_predict))\n",
    "    acc_list.append(accuracy_score(y_test,y_predict))\n",
    "    # 保存每个模型的参数\n",
    "    model_weights =[]\n",
    "    # 开始训练模型\n",
    "    for j in range(clients):\n",
    "        model_weights .append (get_1D_modelWeights_SVM(models[j]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 对每个模型进行加密"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start = time .time ()\n",
    "for j in range(clients):\n",
    "    print(j)\n",
    "    print(\"svm:client\"+\":enc\")\n",
    "    enc_model = Encrypt_model(pai,pk,model_weights[j])\n",
    "    print(\"enc_model len:\",len(enc_model))\n",
    "    enc_models.append(enc_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 聚合计算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_sum_model = []\n",
    "for k in range(len(enc_models[0])):\n",
    "    temp_sum=pai.encrypti(pk,0)\n",
    "    for j in range(clients):\n",
    "            temp_sum+= enc_models[j][k]\n",
    "    all_sum_model.append(temp_sum)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型解密。将聚合后的参数传回客户端进行解密"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for j in range(clients):\n",
    "    print(str(j)+\"client\")\n",
    "    temp = Decrypt_list(pai,all_sum_model,sk,pk)\n",
    "    temp = TransformModel_into_float(temp)\n",
    "    temp = np.array(temp)\n",
    "    temp = temp/clients\n",
    "    coefs,intercepts = recover_svm_from_1D(models[j],temp)\n",
    "    models[j].coef_= coefs\n",
    "    models[j].intercept_= intercepts"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.2"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
